import os, wandb, collections
from src.utils import *
from src.configs.config import default_argument_parser, setup

os.environ["WANDB__SERVICE_WAIT"] = "300"

# main fucntion 
if __name__ == "__main__":

    print("starting the main process")
    args = default_argument_parser().parse_args()
    cfg = setup(args)
    print("got cfg")

    if cfg.wandb:
        wandb.login()

    local_rank = cfg.local_rank

    assert cfg.action in [
        "train_classifier",
        "train_generator"
        ]

    trash_path, log_path, img_path, model_path, group_log_path = generate_path(cfg)
    print("path generated")

    device = f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu"
    cfg.device = device

    LOG_FORMAT = "%(asctime)s | %(name)s " + f"| {device} " + "| %(message)s" 
    logging.basicConfig(filename=log_path, level=logging.INFO, format=LOG_FORMAT)
    
    logger = logging.getLogger(module_structure(__file__))
    # logger = logging.getLogger()
    logger.info(f"Happy logging :)")

    write_args(args, logger)

    logger.info(f"trash path: {trash_path}")
    logger.info(f"log path: {log_path}")
    logger.info(f"image path: {img_path}")
    logger.info(f"model path: {model_path}")

    cfg.pathes.trash_path = trash_path
    cfg.pathes.log_path = log_path
    cfg.pathes.img_path = img_path
    cfg.pathes.model_path = model_path

    if cfg.ddp:
        establish_communication(device = local_rank)
        print("communication established")

    if cfg.action == "train_classifier":
        logger.info(f"into train classifier")
        from src.trainer.train_classifier import train
        cfg.Classifier.path = os.path.join(cfg.pathes.model_path, "classifier.pt")
        if cfg.wandb:
            wandb.init(
                # set the wandb project where this run will be logged
                project="c2g_c_autodl",
                
                # track hyperparameters and run metadata
                config=cfg.to_dict(),
                tags = [cfg.Data.Name, cfg.Classifier.Name, cfg.action],
            )
        results_dict = train(cfg)
        if cfg.wandb:
            wandb.finish()

    elif cfg.action == "train_generator":
        logger.info(f"into train generator")
        from src.trainer.train_generator import train
        cfg.Generator.path = os.path.join(cfg.pathes.model_path, "generator.pt")
        if os.path.exists(cfg.Generator.path):
            print("Already finished")
            results_dict = {}
        else:
            cfg.muer.path = os.path.join(cfg.pathes.model_path, "muer.pt")
            cfg.alpha.path = os.path.join(cfg.pathes.model_path, "alpha.pt")
            if cfg.wandb:
                wandb.init(
                    # set the wandb project where this run will be logged
                    project="c2g_g_autodl",
                    
                    # track hyperparameters and run metadata
                    config=cfg.to_dict(),
                    tags = [cfg.Data.Name, cfg.Generator.Name, cfg.action],
                )
            results_dict = train(cfg)
            if cfg.wandb:
                wandb.finish()

    if results_dict:
        group_log_dic = dict(zip(args.opts[::2], args.opts[1::2]))
        group_log_dic = collections.OrderedDict(sorted(group_log_dic.items()))
        group_log_dic.update(results_dict)
        write_group_log(group_log_dic, group_log_path)

    try:
        logger.info(f"\n{gpu_summary(device)}")
    except Exception as e:
        print(e)
    
    print("done")
